Twin Delayed DDPG (TD3) — from scratch in PyTorch#
TD3 (Fujimoto, van Hoof, Meger, 2018) is a deterministic actor-critic algorithm for continuous control. It improves DDPG with three small but crucial modifications:
Twin critics: learn two Q-functions and use the minimum in the bootstrap target.
Target policy smoothing: add clipped noise to the target action when computing the target Q.
Delayed policy updates: update the actor (and target networks) less often than the critics.
In this notebook we:
write the TD3 update equations precisely (LaTeX)
implement TD3 at a low level in PyTorch (no RL libraries)
train on a Gymnasium environment and plot episodic returns (Plotly)
Learning goals#
Understand why DDPG overestimates and how TD3 fixes it
Implement replay buffer + target networks + twin critics + delayed updates
Train a working agent and visualize learning curves
Prerequisites#
Basic PyTorch (modules, optimizers, autograd)
Q-learning / bootstrapping and the Bellman equation
Actor-critic idea (policy + value function)
Continuous action spaces (e.g., Pendulum)
import copy
import time
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import os
import plotly.io as pio
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
TORCH_AVAILABLE = True
except Exception as e:
TORCH_AVAILABLE = False
_TORCH_IMPORT_ERROR = e
try:
import gymnasium as gym
GYM_AVAILABLE = True
except Exception as e:
GYM_AVAILABLE = False
_GYM_IMPORT_ERROR = e
pio.templates.default = 'plotly_white'
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)
SEED = 42
rng = np.random.default_rng(SEED)
# --- Run configuration ---
FAST_RUN = True
ENV_ID = 'Pendulum-v1'
TOTAL_TIMESTEPS = 10_000 if FAST_RUN else 200_000
START_STEPS = 1_000 if FAST_RUN else 10_000
UPDATE_AFTER = 1_000 if FAST_RUN else 10_000
BATCH_SIZE = 256
BUFFER_SIZE = 200_000
# TD3 hyperparameters
GAMMA = 0.99
TAU = 0.005
ACTOR_LR = 1e-3
CRITIC_LR = 1e-3
POLICY_DELAY = 2
TARGET_POLICY_NOISE = 0.2
TARGET_NOISE_CLIP = 0.5
EXPLORATION_NOISE = 0.1
HIDDEN_SIZES = (256, 256)
1) TD3: the exact updates (twin critics + delayed actor)#
We use:
deterministic policy (actor) \(a = \pi_\phi(s)\)
two critics \(Q_{\theta_1}(s,a)\) and \(Q_{\theta_2}(s,a)\)
target networks \((\phi', \theta_1', \theta_2')\) updated by Polyak averaging
Given a transition \((s,a,r,s',\text{terminal})\) sampled from the replay buffer, TD3 builds the target in three steps.
1. Target policy smoothing#
TD3 does not evaluate the target critics at the raw target action \(\pi_{\phi'}(s')\). Instead it adds clipped Gaussian noise:
\[
\tilde{a}' = \operatorname{clip}\big(\pi_{\phi'}(s') + \epsilon,\; a_{\text{low}},\; a_{\text{high}}\big),
\qquad
\epsilon \sim \operatorname{clip}\big(\mathcal{N}(0, \sigma^2),\; -c,\; c\big),
\]
where \(\sigma\) is TARGET_POLICY_NOISE and \(c\) is TARGET_NOISE_CLIP.
Intuition: this makes the target Q-value less sensitive to small action errors and prevents the critic from exploiting sharp, unrealistic peaks in \(Q\).
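A minimal sketch of this step (a standalone helper mirroring the logic inside the agent's train_step below; sigma and clip correspond to TARGET_POLICY_NOISE and TARGET_NOISE_CLIP):
import torch

def smoothed_target_action(actor_target, next_obs, low_t, high_t, sigma=0.2, clip=0.5):
    # pi_phi'(s') plus clipped Gaussian noise, then clipped back into the action box
    a = actor_target(next_obs)
    eps = (torch.randn_like(a) * sigma).clamp(-clip, clip)
    return torch.max(torch.min(a + eps, high_t), low_t)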
2. Twin critics (min target)#
Compute both target Q-values and take the minimum:
\[
y = r + \gamma \,(1 - \text{terminal})\, \min_{i=1,2} Q_{\theta_i'}(s', \tilde{a}').
\]
Each critic minimizes an MSE to this same target:
\[
\mathcal{L}(\theta_i) = \mathbb{E}\Big[\big(Q_{\theta_i}(s, a) - y\big)^2\Big], \qquad i = 1, 2.
\]
Taking the minimum is a simple bias-reduction trick: it turns DDPG’s optimistic target into a more conservative estimate, reducing overestimation error.
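A minimal sketch of the critic loss, assuming a twin critic module like the TwinCritic defined later (its forward returns both Q-values) and a 0/1 terminal mask done:
import torch
import torch.nn.functional as F

def twin_critic_loss(critic, critic_target, obs, act, rew, next_obs, done, next_action, gamma=0.99):
    # Both critics regress onto the same conservative target y
    with torch.no_grad():
        q1_t, q2_t = critic_target(next_obs, next_action)
        y = rew + gamma * (1.0 - done) * torch.min(q1_t, q2_t)
    q1, q2 = critic(obs, act)
    return F.mse_loss(q1, y) + F.mse_loss(q2, y)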
3. Delayed policy updates#
The critics are updated every gradient step. The actor is updated only every \(d\) critic updates (e.g. \(d = 2\)), by ascending the deterministic policy gradient of \(Q_{\theta_1}\):
\[
\nabla_\phi J(\phi) = \mathbb{E}\Big[\nabla_a Q_{\theta_1}(s, a)\big|_{a = \pi_\phi(s)} \, \nabla_\phi \pi_\phi(s)\Big].
\]
In code we minimize the negative:
\[
\mathcal{L}(\phi) = -\,\mathbb{E}\big[Q_{\theta_1}(s, \pi_\phi(s))\big].
\]
When (and only when) we update the actor, we also update all target networks with Polyak averaging:
\[
\theta_i' \leftarrow \tau \theta_i + (1 - \tau)\,\theta_i', \qquad \phi' \leftarrow \tau \phi + (1 - \tau)\,\phi'.
\]
Delaying the actor update lets the critics move closer to their fixed point, so the actor sees a less noisy / less biased gradient.
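A minimal sketch of the Polyak step (the agent below implements the same update in _soft_update_):
import torch

@torch.no_grad()
def polyak_update(source, target, tau=0.005):
    # theta' <- tau * theta + (1 - tau) * theta', parameter-wise
    for p, p_targ in zip(source.parameters(), target.parameters()):
        p_targ.mul_(1.0 - tau)
        p_targ.add_(tau * p)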
2) Implementation roadmap#
We will implement TD3 as a small set of building blocks:
Gymnasium environment helpers (reset/step API differences)
Replay buffer (NumPy storage, PyTorch sampling)
Actor network \(\pi_\phi(s)\)
Twin critic networks \(Q_{\theta_1}(s,a), Q_{\theta_2}(s,a)\)
TD3 update step (critic update every step, actor + target update every POLICY_DELAY steps)
Training loop + Plotly learning curve
def set_global_seeds(seed: int) -> None:
np.random.seed(seed)
if TORCH_AVAILABLE:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def env_reset(env, seed=None):
out = env.reset(seed=seed) if seed is not None else env.reset()
if isinstance(out, tuple) and len(out) == 2:
obs, info = out
return obs, info
return out, {}
def env_step(env, action):
out = env.step(action)
if isinstance(out, tuple) and len(out) == 5:
next_obs, reward, terminated, truncated, info = out
done = bool(terminated or truncated)
terminal = bool(terminated) # time-limit truncation is not a terminal state
return next_obs, float(reward), done, terminal, info
if isinstance(out, tuple) and len(out) == 4:
next_obs, reward, done, info = out
terminal = bool(done)
return next_obs, float(reward), bool(done), terminal, info
raise ValueError('Unexpected env.step(...) output format')
if not TORCH_AVAILABLE:
raise RuntimeError(f'PyTorch import failed: {_TORCH_IMPORT_ERROR}')
if not GYM_AVAILABLE:
raise RuntimeError(f'Gymnasium import failed: {_GYM_IMPORT_ERROR}')
set_global_seeds(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = gym.make(ENV_ID)
obs, _ = env_reset(env, seed=SEED)
obs_dim = int(np.prod(env.observation_space.shape))
act_dim = int(np.prod(env.action_space.shape))
action_low = env.action_space.low.astype(np.float32)
action_high = env.action_space.high.astype(np.float32)
print('env:', ENV_ID)
print('obs_dim:', obs_dim)
print('act_dim:', act_dim)
print('action_low:', action_low)
print('action_high:', action_high)
print('device:', device)
env: Pendulum-v1
obs_dim: 3
act_dim: 1
action_low: [-2.]
action_high: [2.]
device: cpu
class ReplayBuffer:
def __init__(self, obs_dim: int, act_dim: int, size: int, seed: int, device: torch.device):
self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
self.next_obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
self.rew_buf = np.zeros((size, 1), dtype=np.float32)
self.done_buf = np.zeros((size, 1), dtype=np.float32)
self.max_size = int(size)
self.ptr = 0
self.size = 0
self.rng = np.random.default_rng(seed)
self.device = device
def add(self, obs, act, rew: float, next_obs, terminal: bool) -> None:
self.obs_buf[self.ptr] = np.asarray(obs, dtype=np.float32).reshape(-1)
self.next_obs_buf[self.ptr] = np.asarray(next_obs, dtype=np.float32).reshape(-1)
self.act_buf[self.ptr] = np.asarray(act, dtype=np.float32).reshape(-1)
self.rew_buf[self.ptr] = float(rew)
self.done_buf[self.ptr] = 1.0 if terminal else 0.0
self.ptr = (self.ptr + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
def sample(self, batch_size: int):
if self.size < batch_size:
raise ValueError(f'Not enough samples: size={self.size}, batch_size={batch_size}')
idxs = self.rng.integers(0, self.size, size=batch_size)
obs = torch.as_tensor(self.obs_buf[idxs], device=self.device)
act = torch.as_tensor(self.act_buf[idxs], device=self.device)
rew = torch.as_tensor(self.rew_buf[idxs], device=self.device)
next_obs = torch.as_tensor(self.next_obs_buf[idxs], device=self.device)
done = torch.as_tensor(self.done_buf[idxs], device=self.device)
return obs, act, rew, next_obs, done
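As a quick sanity check of the buffer (toy shapes, entirely separate from the training run below):
_buf = ReplayBuffer(obs_dim=3, act_dim=1, size=8, seed=0, device=torch.device('cpu'))
for _ in range(8):
    _buf.add(np.zeros(3), np.zeros(1), 0.0, np.zeros(3), False)
_o, _a, _r, _no, _d = _buf.sample(4)
print(_o.shape, _a.shape, _r.shape)  # torch.Size([4, 3]) torch.Size([4, 1]) torch.Size([4, 1])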
def mlp(layer_sizes, activation=nn.ReLU, output_activation=nn.Identity):
layers = []
for i in range(len(layer_sizes) - 1):
act = activation if i < len(layer_sizes) - 2 else output_activation
layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
layers.append(act())
return nn.Sequential(*layers)
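# Illustration: for sizes [3, 256, 256, 1], mlp() builds
# Linear(3, 256) -> ReLU -> Linear(256, 256) -> ReLU -> Linear(256, 1) -> Identity
print(mlp([3, 256, 256, 1]))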
class Actor(nn.Module):
def __init__(self, obs_dim: int, act_dim: int, hidden_sizes, action_low, action_high):
super().__init__()
self.net = mlp([obs_dim, *hidden_sizes, act_dim], activation=nn.ReLU, output_activation=nn.Identity)
action_low_t = torch.as_tensor(action_low, dtype=torch.float32)
action_high_t = torch.as_tensor(action_high, dtype=torch.float32)
self.register_buffer('action_scale', (action_high_t - action_low_t) / 2.0)
self.register_buffer('action_bias', (action_high_t + action_low_t) / 2.0)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
a = torch.tanh(self.net(obs))
return a * self.action_scale + self.action_bias
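# Sanity check (illustrative, throwaway net): tanh squashing keeps actions in [low, high].
_actor = Actor(3, 1, (32, 32), np.array([-2.0], dtype=np.float32), np.array([2.0], dtype=np.float32))
with torch.no_grad():
    _a = _actor(torch.randn(5, 3))
print(float(_a.min()) >= -2.0, float(_a.max()) <= 2.0)  # True True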
class QNetwork(nn.Module):
def __init__(self, obs_dim: int, act_dim: int, hidden_sizes):
super().__init__()
self.net = mlp([obs_dim + act_dim, *hidden_sizes, 1], activation=nn.ReLU, output_activation=nn.Identity)
def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
x = torch.cat([obs, act], dim=-1)
return self.net(x)
class TwinCritic(nn.Module):
def __init__(self, obs_dim: int, act_dim: int, hidden_sizes):
super().__init__()
self.q1 = QNetwork(obs_dim, act_dim, hidden_sizes)
self.q2 = QNetwork(obs_dim, act_dim, hidden_sizes)
def forward(self, obs: torch.Tensor, act: torch.Tensor):
return self.q1(obs, act), self.q2(obs, act)
def q1_only(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
return self.q1(obs, act)
class TD3Agent:
def __init__(
self,
obs_dim: int,
act_dim: int,
hidden_sizes,
action_low,
action_high,
device: torch.device,
gamma: float = 0.99,
tau: float = 0.005,
actor_lr: float = 1e-3,
critic_lr: float = 1e-3,
policy_delay: int = 2,
target_policy_noise: float = 0.2,
target_noise_clip: float = 0.5,
):
self.device = device
self.actor = Actor(obs_dim, act_dim, hidden_sizes, action_low, action_high).to(device)
self.actor_target = copy.deepcopy(self.actor).to(device)
self.critic = TwinCritic(obs_dim, act_dim, hidden_sizes).to(device)
self.critic_target = copy.deepcopy(self.critic).to(device)
self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
self.gamma = float(gamma)
self.tau = float(tau)
self.policy_delay = int(policy_delay)
self.target_policy_noise = float(target_policy_noise)
self.target_noise_clip = float(target_noise_clip)
self.action_low_t = torch.as_tensor(action_low, dtype=torch.float32, device=device)
self.action_high_t = torch.as_tensor(action_high, dtype=torch.float32, device=device)
self.total_it = 0
# Targets start identical to online nets
self.actor_target.load_state_dict(self.actor.state_dict())
self.critic_target.load_state_dict(self.critic.state_dict())
@torch.no_grad()
def select_action(self, obs, noise_scale: float = 0.0):
obs_t = torch.as_tensor(np.asarray(obs, dtype=np.float32).reshape(1, -1), device=self.device)
action = self.actor(obs_t).cpu().numpy().reshape(-1)
if noise_scale and noise_scale > 0:
action = action + np.random.normal(0.0, noise_scale, size=action.shape).astype(np.float32)
action = np.clip(action, self.action_low_t.cpu().numpy(), self.action_high_t.cpu().numpy())
return action
def _soft_update_(self, source: nn.Module, target: nn.Module) -> None:
with torch.no_grad():
for p, p_targ in zip(source.parameters(), target.parameters()):
p_targ.data.mul_(1.0 - self.tau)
p_targ.data.add_(self.tau * p.data)
def train_step(self, replay_buffer: ReplayBuffer, batch_size: int):
self.total_it += 1
obs, act, rew, next_obs, done = replay_buffer.sample(batch_size)
# --- Critic update (every step) ---
with torch.no_grad():
noise = torch.randn_like(act) * self.target_policy_noise
noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
next_action = self.actor_target(next_obs) + noise
next_action = torch.max(torch.min(next_action, self.action_high_t), self.action_low_t)
target_q1, target_q2 = self.critic_target(next_obs, next_action)
target_q = torch.min(target_q1, target_q2)
y = rew + (1.0 - done) * self.gamma * target_q
current_q1, current_q2 = self.critic(obs, act)
critic_loss = F.mse_loss(current_q1, y) + F.mse_loss(current_q2, y)
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
info = {'critic_loss': float(critic_loss.item())}
# --- Delayed actor + target updates ---
if self.total_it % self.policy_delay == 0:
actor_loss = -self.critic.q1_only(obs, self.actor(obs)).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
info['actor_loss'] = float(actor_loss.item())
self._soft_update_(self.critic, self.critic_target)
self._soft_update_(self.actor, self.actor_target)
return info
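Before the real run, a one-off smoke test of train_step on throwaway objects (this does not touch the agent or buffer used below):
_buf2 = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=512, seed=0, device=device)
for _ in range(BATCH_SIZE):
    _buf2.add(rng.standard_normal(obs_dim), rng.uniform(action_low, action_high),
              0.0, rng.standard_normal(obs_dim), False)
_agent2 = TD3Agent(obs_dim, act_dim, (32, 32), action_low, action_high, device)
print(_agent2.train_step(_buf2, batch_size=BATCH_SIZE))  # first call: critic loss only (policy_delay=2)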
replay = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=BUFFER_SIZE, seed=SEED, device=device)
agent = TD3Agent(
obs_dim=obs_dim,
act_dim=act_dim,
hidden_sizes=HIDDEN_SIZES,
action_low=action_low,
action_high=action_high,
device=device,
gamma=GAMMA,
tau=TAU,
actor_lr=ACTOR_LR,
critic_lr=CRITIC_LR,
policy_delay=POLICY_DELAY,
target_policy_noise=TARGET_POLICY_NOISE,
target_noise_clip=TARGET_NOISE_CLIP,
)
episode_returns = []
episode_lengths = []
critic_losses = []
actor_losses = []
obs, _ = env_reset(env, seed=SEED)
ep_return = 0.0
ep_len = 0
t0 = time.time()
for t in range(TOTAL_TIMESTEPS):
if t < START_STEPS:
action = env.action_space.sample()
else:
action = agent.select_action(obs, noise_scale=EXPLORATION_NOISE)
next_obs, reward, done, terminal, _info = env_step(env, action)
replay.add(obs, action, reward, next_obs, terminal)
obs = next_obs
ep_return += reward
ep_len += 1
if t >= UPDATE_AFTER:
train_info = agent.train_step(replay, batch_size=BATCH_SIZE)
critic_losses.append(train_info['critic_loss'])
if 'actor_loss' in train_info:
actor_losses.append(train_info['actor_loss'])
if done:
episode_returns.append(ep_return)
episode_lengths.append(ep_len)
if len(episode_returns) % 5 == 0 or not FAST_RUN:
elapsed = time.time() - t0
print(
f"Episode {len(episode_returns):4d} | return {ep_return:9.1f} | len {ep_len:3d} | "
f"t {t + 1:6d}/{TOTAL_TIMESTEPS} | elapsed {elapsed:6.1f}s"
)
obs, _ = env_reset(env)
ep_return = 0.0
ep_len = 0
env.close()
print('episodes:', len(episode_returns))
print('last return:', episode_returns[-1] if episode_returns else None)
Episode 5 | return -1301.5 | len 200 | t 1000/10000 | elapsed 0.0s
Episode 10 | return -1504.4 | len 200 | t 2000/10000 | elapsed 11.3s
Episode 15 | return -1295.4 | len 200 | t 3000/10000 | elapsed 24.7s
Episode 20 | return -1086.4 | len 200 | t 4000/10000 | elapsed 38.1s
Episode 25 | return -817.4 | len 200 | t 5000/10000 | elapsed 51.8s
Episode 30 | return -1039.3 | len 200 | t 6000/10000 | elapsed 65.4s
Episode 35 | return -526.4 | len 200 | t 7000/10000 | elapsed 78.9s
Episode 40 | return -130.1 | len 200 | t 8000/10000 | elapsed 92.5s
Episode 45 | return -244.9 | len 200 | t 9000/10000 | elapsed 106.1s
Episode 50 | return -247.3 | len 200 | t 10000/10000 | elapsed 119.8s
episodes: 50
last return: -247.34952581981875
# Plot episodic returns
df = pd.DataFrame(
{
'episode': np.arange(1, len(episode_returns) + 1),
'return': episode_returns,
'length': episode_lengths,
}
)
window = min(10, max(1, len(df)))
df['return_ma'] = df['return'].rolling(window=window, min_periods=1).mean()
fig = go.Figure()
fig.add_trace(go.Scatter(x=df['episode'], y=df['return'], mode='lines+markers', name='Return'))
fig.add_trace(go.Scatter(x=df['episode'], y=df['return_ma'], mode='lines', name=f'{window}-episode MA'))
fig.update_layout(
title=f'TD3 training on {ENV_ID}: episodic return',
xaxis_title='Episode',
yaxis_title='Return',
)
fig.show()
Notes, diagnostics, and common pitfalls#
Terminal masking: for time-limit truncation you typically still bootstrap, so we mask only on terminated (Gymnasium), not on truncated (see the sketch after this list).
Twin critics: the key is using the minimum only in the bootstrap target \(y\) (not necessarily everywhere).
Delayed updates: do not update the actor every step; update it only every POLICY_DELAY critic updates.
Target policy smoothing: the noise added to target actions is separate from exploration noise.
Exploration: TD3's policy is deterministic; you must add noise to actions during data collection.
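A small self-contained illustration of the flag handling (it makes its own short-lived env):
_env = gym.make('Pendulum-v1')
_obs, _ = _env.reset(seed=0)
_next, _r, _terminated, _truncated, _ = _env.step(_env.action_space.sample())
_done = bool(_terminated or _truncated)  # ends the episode loop
_mask = 0.0 if _terminated else 1.0      # multiplies gamma * min-Q in the target y
_env.close()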
Stable-Baselines3 TD3 (reference implementation)#
Stable-Baselines3 (SB3) includes a PyTorch TD3 implementation: https://stable-baselines3.readthedocs.io/en/master/modules/td3.html
This is useful as a reference and a quick way to validate your intuition against a well-tested baseline.
If you want to run it locally:
pip install stable-baselines3
If you have SB3 installed, a minimal training script looks like:
import gymnasium as gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
env = gym.make('Pendulum-v1')
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = TD3(
policy='MlpPolicy',
env=env,
action_noise=action_noise,
verbose=1,
)
model.learn(total_timesteps=100_000)
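# After training, a quick greedy evaluation rollout is a useful sketch
# (deterministic=True disables exploration noise; reuses env and model from above).
obs, _ = env.reset()
ep_return, done = 0.0, False
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    ep_return += float(reward)
    done = bool(terminated or truncated)
print('evaluation return:', ep_return)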
At the end of this notebook we summarize SB3’s TD3 hyperparameters.
Stable-Baselines3 TD3 hyperparameters (glossary + defaults)#
Web research source: https://stable-baselines3.readthedocs.io/en/master/modules/td3.html
Constructor signature (defaults):
TD3(policy, env, learning_rate=0.001, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, n_steps=1, policy_delay=2, target_policy_noise=0.2, target_noise_clip=0.5, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
Glossary:
policy: policy class/name (e.g., MlpPolicy, CnnPolicy).
env: environment instance or env ID string.
learning_rate (default 1e-3): Adam learning rate (SB3 uses the same LR for actor and critics).
buffer_size (default 1_000_000): replay buffer capacity.
learning_starts (default 100): number of environment steps collected before training begins.
batch_size (default 256): mini-batch size sampled from replay.
tau (default 0.005): Polyak coefficient \(\tau\) for target network updates.
gamma (default 0.99): discount factor \(\gamma\).
train_freq (default 1): how often to train (steps), or a tuple like (n, 'step') / (n, 'episode').
gradient_steps (default 1): gradient updates per training iteration.
action_noise (default None): exploration noise used when collecting data (e.g. Gaussian or OU noise).
policy_delay (default 2): actor/target update period \(d\) (critics update every step).
target_policy_noise (default 0.2): \(\sigma\) in target policy smoothing.
target_noise_clip (default 0.5): \(c\) in target policy smoothing (clip range).
replay_buffer_class (default None): custom replay buffer class.
replay_buffer_kwargs (default None): kwargs passed to the replay buffer.
optimize_memory_usage (default False): memory-efficient replay buffer variant.
n_steps (default 1): n-step returns (when >1 uses an n-step replay buffer).
stats_window_size (default 100): logging window size (episodes averaged).
tensorboard_log (default None): TensorBoard log directory.
policy_kwargs (default None): policy/network architecture options.
verbose (default 0): verbosity (0/1/2).
seed (default None): RNG seed.
device (default 'auto'): device selection (CPU/GPU).
_init_setup_model (default True): whether to build networks at init.
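For example, mirroring this notebook's settings in SB3 (a sketch; omitted arguments keep the defaults above):
from stable_baselines3 import TD3

model = TD3('MlpPolicy', 'Pendulum-v1', learning_rate=1e-3, buffer_size=200_000,
            batch_size=256, tau=0.005, gamma=0.99, policy_delay=2,
            target_policy_noise=0.2, target_noise_clip=0.5, seed=42)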
References#
Fujimoto, van Hoof, Meger (2018): Addressing Function Approximation Error in Actor-Critic Methods (TD3)
Stable-Baselines3 docs / source code (TD3)